{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datalab: Advanced workflows to audit your data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cleanlab offers a `Datalab` object to identify various issues in your machine learning datasets that may negatively impact models if not addressed. By default, `Datalab` can help you identify noisy labels, outliers, (near) duplicates, and other types of problems that commonly occur in real-world data.\n",
    "\n",
    "`Datalab` performs these checks by utilizing the (probabilistic) predictions from *any* ML model that has already been trained or its learned representations of the data. Underneath the hood, this class calls all the appropriate cleanlab methods for your dataset and provided model outputs, in order to best audit the data and alert you of important issues. This makes it easy to apply many functionalities of this library all within a single line of code. \n",
    "\n",
    "**This tutorial will demonstrate some advanced functionalities of Datalab including:**\n",
    "\n",
    "- Incremental issue search\n",
    "- Specifying nondefault arguments to issue checks\n",
    "- Save and load Datalab objects\n",
    "- Adding a custom IssueManager\n",
    "\n",
    "If you are new to `Datalab`, check out this [quickstart tutorial](datalab_quickstart.html) for a 5-min introduction!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "Quickstart\n",
    "<br/>\n",
    "    \n",
    "Already have (out-of-sample) `pred_probs` from a model trained on an existing set of labels? Maybe you have some `features` as well? Run the code below to examine your dataset for multiple types of issues.\n",
    "\n",
    "<div  class=markdown markdown=\"1\" style=\"background:white;margin:16px\">  \n",
    "    \n",
    "```ipython3 \n",
    "from cleanlab import Datalab\n",
    "\n",
    "lab = Datalab(data=your_dataset, label_name=\"column_name_of_labels\")\n",
    "lab.find_issues(features=your_feature_matrix, pred_probs=your_pred_probs)\n",
    "\n",
    "lab.report()\n",
    "```\n",
    "   \n",
    "</div>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install and import required dependencies"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Datalab` has additional dependencies that are not included in the standard installation of cleanlab.\n",
    "\n",
    "You can use pip to install all packages required for this tutorial as follows:\n",
    "\n",
    "```ipython3\n",
    "!pip install matplotlib \n",
    "!pip install \"cleanlab[datalab]\"\n",
    "\n",
    "# Make sure to install the version corresponding to this tutorial\n",
    "# E.g. if viewing master branch documentation:\n",
    "#     !pip install git+https://github.com/cleanlab/cleanlab.git\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# Package installation (hidden on docs website).\n",
    "dependencies = [\"cleanlab\", \"matplotlib\", \"datasets\"]  # TODO: make sure this list is updated\n",
    "\n",
    "if \"google.colab\" in str(get_ipython()):  # Check if it's running in Google Colab\n",
    "    %pip install cleanlab  # for colab\n",
    "    cmd = ' '.join([dep for dep in dependencies if dep != \"cleanlab\"])\n",
    "    %pip install $cmd\n",
    "else:\n",
    "    dependencies_test = [dependency.split('>')[0] if '>' in dependency \n",
    "                         else dependency.split('<')[0] if '<' in dependency \n",
    "                         else dependency.split('=')[0] for dependency in dependencies]\n",
    "    missing_dependencies = []\n",
    "    for dependency in dependencies_test:\n",
    "        try:\n",
    "            __import__(dependency)\n",
    "        except ImportError:\n",
    "            missing_dependencies.append(dependency)\n",
    "\n",
    "    if len(missing_dependencies) > 0:\n",
    "        print(\"Missing required dependencies:\")\n",
    "        print(*missing_dependencies, sep=\", \")\n",
    "        print(\"\\nPlease install them before running the rest of this notebook.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "\n",
    "from cleanlab import Datalab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create and load the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll load a toy classification dataset for this tutorial. The dataset has two numerical features and a label column with three classes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<details><summary>See the code for data generation. **(click to expand)**</summary>\n",
    "    \n",
    "```ipython3\n",
    "# Note: This pulldown content is for docs.cleanlab.ai, if running on local Jupyter or Colab, please ignore it.\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from cleanlab.benchmarking.noise_generation import (\n",
    "    generate_noise_matrix_from_trace,\n",
    "    generate_noisy_labels,\n",
    ")\n",
    "\n",
    "SEED = 123\n",
    "np.random.seed(SEED)\n",
    "\n",
    "BINS = {\n",
    "    \"low\": [-np.inf, 3.3],\n",
    "    \"mid\": [3.3, 6.6],\n",
    "    \"high\": [6.6, +np.inf],\n",
    "}\n",
    "\n",
    "BINS_MAP = {\n",
    "    \"low\": 0,\n",
    "    \"mid\": 1,\n",
    "    \"high\": 2,\n",
    "}\n",
    "\n",
    "\n",
    "def create_data():\n",
    "\n",
    "    X = np.random.rand(250, 2) * 5\n",
    "    y = np.sum(X, axis=1)\n",
    "    # Map y to bins based on the BINS dict\n",
    "    y_bin = np.array([k for y_i in y for k, v in BINS.items() if v[0] <= y_i < v[1]])\n",
    "    y_bin_idx = np.array([BINS_MAP[k] for k in y_bin])\n",
    "\n",
    "    # Split into train and test\n",
    "    X_train, X_test, y_train, y_test, y_train_idx, y_test_idx = train_test_split(\n",
    "        X, y_bin, y_bin_idx, test_size=0.5, random_state=SEED\n",
    "    )\n",
    "\n",
    "    # Add several (5) out-of-distribution points. Sliding them along the decision boundaries\n",
    "    # to make them look like they are out-of-frame\n",
    "    X_out = np.array(\n",
    "        [\n",
    "            [-1.5, 3.0],\n",
    "            [-1.75, 6.5],\n",
    "            [1.5, 7.2],\n",
    "            [2.5, -2.0],\n",
    "            [5.5, 7.0],\n",
    "        ]\n",
    "    )\n",
    "    # Add a near duplicate point to the last outlier, with some tiny noise added\n",
    "    near_duplicate = X_out[-1:] + np.random.rand(1, 2) * 1e-6\n",
    "    X_out = np.concatenate([X_out, near_duplicate])\n",
    "\n",
    "    y_out = np.sum(X_out, axis=1)\n",
    "    y_out_bin = np.array([k for y_i in y_out for k, v in BINS.items() if v[0] <= y_i < v[1]])\n",
    "    y_out_bin_idx = np.array([BINS_MAP[k] for k in y_out_bin])\n",
    "\n",
    "    # Add to train\n",
    "    X_train = np.concatenate([X_train, X_out])\n",
    "    y_train = np.concatenate([y_train, y_out])\n",
    "    y_train_idx = np.concatenate([y_train_idx, y_out_bin_idx])\n",
    "\n",
    "    # Add an exact duplicate example to the training set\n",
    "    exact_duplicate_idx = np.random.randint(0, len(X_train))\n",
    "    X_duplicate = X_train[exact_duplicate_idx, None]\n",
    "    y_duplicate = y_train[exact_duplicate_idx, None]\n",
    "    y_duplicate_idx = y_train_idx[exact_duplicate_idx, None]\n",
    "\n",
    "    # Add to train\n",
    "    X_train = np.concatenate([X_train, X_duplicate])\n",
    "    y_train = np.concatenate([y_train, y_duplicate])\n",
    "    y_train_idx = np.concatenate([y_train_idx, y_duplicate_idx])\n",
    "\n",
    "    py = np.bincount(y_train_idx) / float(len(y_train_idx))\n",
    "    m = len(BINS)\n",
    "\n",
    "    noise_matrix = generate_noise_matrix_from_trace(\n",
    "        m,\n",
    "        trace=0.9 * m,\n",
    "        py=py,\n",
    "        valid_noise_matrix=True,\n",
    "        seed=SEED,\n",
    "    )\n",
    "\n",
    "    noisy_labels_idx = generate_noisy_labels(y_train_idx, noise_matrix)\n",
    "    noisy_labels = np.array([list(BINS_MAP.keys())[i] for i in noisy_labels_idx])\n",
    "\n",
    "    return X_train, y_train_idx, noisy_labels, noisy_labels_idx, X_out, X_duplicate\n",
    "```\n",
    "\n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from cleanlab.benchmarking.noise_generation import (\n",
    "    generate_noise_matrix_from_trace,\n",
    "    generate_noisy_labels,\n",
    ")\n",
    "\n",
    "SEED = 123\n",
    "np.random.seed(SEED)\n",
    "\n",
    "BINS = {\n",
    "    \"low\": [-np.inf, 3.3],\n",
    "    \"mid\": [3.3, 6.6],\n",
    "    \"high\": [6.6, +np.inf],\n",
    "}\n",
    "\n",
    "BINS_MAP = {\n",
    "    \"low\": 0,\n",
    "    \"mid\": 1,\n",
    "    \"high\": 2,\n",
    "}\n",
    "\n",
    "\n",
    "def create_data():\n",
    "\n",
    "    X = np.random.rand(250, 2) * 5\n",
    "    y = np.sum(X, axis=1)\n",
    "    # Map y to bins based on the BINS dict\n",
    "    y_bin = np.array([k for y_i in y for k, v in BINS.items() if v[0] <= y_i < v[1]])\n",
    "    y_bin_idx = np.array([BINS_MAP[k] for k in y_bin])\n",
    "\n",
    "    # Split into train and test\n",
    "    X_train, X_test, y_train, y_test, y_train_idx, y_test_idx = train_test_split(\n",
    "        X, y_bin, y_bin_idx, test_size=0.5, random_state=SEED\n",
    "    )\n",
    "\n",
    "    # Add several (5) out-of-distribution points. Sliding them along the decision boundaries\n",
    "    # to make them look like they are out-of-frame\n",
    "    X_out = np.array(\n",
    "        [\n",
    "            [-1.5, 3.0],\n",
    "            [-1.75, 6.5],\n",
    "            [1.5, 7.2],\n",
    "            [2.5, -2.0],\n",
    "            [5.5, 7.0],\n",
    "        ]\n",
    "    )\n",
    "    # Add a near duplicate point to the last outlier, with some tiny noise added\n",
    "    near_duplicate = X_out[-1:] + np.random.rand(1, 2) * 1e-6\n",
    "    X_out = np.concatenate([X_out, near_duplicate])\n",
    "\n",
    "    y_out = np.sum(X_out, axis=1)\n",
    "    y_out_bin = np.array([k for y_i in y_out for k, v in BINS.items() if v[0] <= y_i < v[1]])\n",
    "    y_out_bin_idx = np.array([BINS_MAP[k] for k in y_out_bin])\n",
    "\n",
    "    # Add to train\n",
    "    X_train = np.concatenate([X_train, X_out])\n",
    "    y_train = np.concatenate([y_train, y_out])\n",
    "    y_train_idx = np.concatenate([y_train_idx, y_out_bin_idx])\n",
    "\n",
    "    # Add an exact duplicate example to the training set\n",
    "    exact_duplicate_idx = np.random.randint(0, len(X_train))\n",
    "    X_duplicate = X_train[exact_duplicate_idx, None]\n",
    "    y_duplicate = y_train[exact_duplicate_idx, None]\n",
    "    y_duplicate_idx = y_train_idx[exact_duplicate_idx, None]\n",
    "\n",
    "    # Add to train\n",
    "    X_train = np.concatenate([X_train, X_duplicate])\n",
    "    y_train = np.concatenate([y_train, y_duplicate])\n",
    "    y_train_idx = np.concatenate([y_train_idx, y_duplicate_idx])\n",
    "\n",
    "    py = np.bincount(y_train_idx) / float(len(y_train_idx))\n",
    "    m = len(BINS)\n",
    "\n",
    "    noise_matrix = generate_noise_matrix_from_trace(\n",
    "        m,\n",
    "        trace=0.9 * m,\n",
    "        py=py,\n",
    "        valid_noise_matrix=True,\n",
    "        seed=SEED,\n",
    "    )\n",
    "\n",
    "    noisy_labels_idx = generate_noisy_labels(y_train_idx, noise_matrix)\n",
    "    noisy_labels = np.array([list(BINS_MAP.keys())[i] for i in noisy_labels_idx])\n",
    "\n",
    "    return X_train, y_train_idx, noisy_labels, noisy_labels_idx, X_out, X_duplicate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, y_train_idx, noisy_labels, noisy_labels_idx, X_out, X_duplicate = create_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We make a scatter plot of the features, with a color corresponding to the observed labels. Incorrect given labels are highlighted in red if they do not match the true label, outliers highlighted with an a black cross, and duplicates highlighted with a cyan cross."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<details><summary>See the code to visualize the data. **(click to expand)**</summary>\n",
    "    \n",
    "```ipython3\n",
    "# Note: This pulldown content is for docs.cleanlab.ai, if running on local Jupyter or Colab, please ignore it.\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_data(X_train, y_train_idx, noisy_labels_idx, X_out, X_duplicate):\n",
    "    # Plot data with clean labels and noisy labels, use BINS_MAP for the legend\n",
    "    fig, ax = plt.subplots(figsize=(8, 6.5))\n",
    "        \n",
    "    low = ax.scatter(X_train[noisy_labels_idx == 0, 0], X_train[noisy_labels_idx == 0, 1], label=\"low\")\n",
    "    mid = ax.scatter(X_train[noisy_labels_idx == 1, 0], X_train[noisy_labels_idx == 1, 1], label=\"mid\")\n",
    "    high = ax.scatter(X_train[noisy_labels_idx == 2, 0], X_train[noisy_labels_idx == 2, 1], label=\"high\")\n",
    "    \n",
    "    ax.set_title(\"Noisy labels\")\n",
    "    ax.set_xlabel(r\"$x_1$\", fontsize=16)\n",
    "    ax.set_ylabel(r\"$x_2$\", fontsize=16)\n",
    "\n",
    "    # Plot true boundaries (x+y=3.3, x+y=6.6)\n",
    "    ax.set_xlim(-3.5, 9.0)\n",
    "    ax.set_ylim(-3.5, 9.0)\n",
    "    ax.plot([-0.7, 4.0], [4.0, -0.7], color=\"k\", linestyle=\"--\", alpha=0.5)\n",
    "    ax.plot([-0.7, 7.3], [7.3, -0.7], color=\"k\", linestyle=\"--\", alpha=0.5)\n",
    "\n",
    "    # Draw red circles around the points that are misclassified (i.e. the points that are in the wrong bin)\n",
    "    for i, (X, y) in enumerate(zip([X_train, X_train], [y_train_idx, noisy_labels_idx])):\n",
    "        for j, (k, v) in enumerate(BINS_MAP.items()):\n",
    "            label_err = ax.scatter(\n",
    "                X[(y == v) & (y != y_train_idx), 0],\n",
    "                X[(y == v) & (y != y_train_idx), 1],\n",
    "                s=180,\n",
    "                marker=\"o\",\n",
    "                facecolor=\"none\",\n",
    "                edgecolors=\"red\",\n",
    "                linewidths=2.5,\n",
    "                alpha=0.5,\n",
    "                label=\"Label error\",\n",
    "            )\n",
    "\n",
    "\n",
    "    outlier = ax.scatter(X_out[:, 0], X_out[:, 1], color=\"k\", marker=\"x\", s=100, linewidth=2, label=\"Outlier\")\n",
    "\n",
    "    # Plot the exact duplicate\n",
    "    dups = ax.scatter(\n",
    "        X_duplicate[:, 0],\n",
    "        X_duplicate[:, 1],\n",
    "        color=\"c\",\n",
    "        marker=\"x\",\n",
    "        s=100,\n",
    "        linewidth=2,\n",
    "        label=\"Duplicates\",\n",
    "    )\n",
    "    \n",
    "    first_legend = ax.legend(handles=[low, mid, high], loc=[0.75, 0.7], title=\"Given Class Label\", alignment=\"left\", title_fontproperties={\"weight\":\"semibold\"})\n",
    "    second_legend = ax.legend(handles=[label_err, outlier, dups], loc=[0.75, 0.45], title=\"Type of Issue\", alignment=\"left\", title_fontproperties={\"weight\":\"semibold\"})\n",
    "    \n",
    "    ax = plt.gca().add_artist(first_legend)\n",
    "    ax = plt.gca().add_artist(second_legend)\n",
    "    plt.tight_layout()\n",
    "```\n",
    "    \n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_data(X_train, y_train_idx, noisy_labels_idx, X_out, X_duplicate):\n",
    "    # Plot data with clean labels and noisy labels, use BINS_MAP for the legend\n",
    "    fig, ax = plt.subplots(figsize=(6, 4))\n",
    "        \n",
    "    low = ax.scatter(X_train[noisy_labels_idx == 0, 0], X_train[noisy_labels_idx == 0, 1], label=\"low\")\n",
    "    mid = ax.scatter(X_train[noisy_labels_idx == 1, 0], X_train[noisy_labels_idx == 1, 1], label=\"mid\")\n",
    "    high = ax.scatter(X_train[noisy_labels_idx == 2, 0], X_train[noisy_labels_idx == 2, 1], label=\"high\")\n",
    "    \n",
    "    ax.set_title(\"Noisy labels\")\n",
    "    ax.set_xlabel(r\"$x_1$\", fontsize=16)\n",
    "    ax.set_ylabel(r\"$x_2$\", fontsize=16)\n",
    "\n",
    "    # Plot true boundaries (x+y=3.3, x+y=6.6)\n",
    "    ax.set_xlim(-2.5, 8.5)\n",
    "    ax.set_ylim(-3.5, 9.0)\n",
    "    ax.plot([-0.7, 4.0], [4.0, -0.7], color=\"k\", linestyle=\"--\", alpha=0.5)\n",
    "    ax.plot([-0.7, 7.3], [7.3, -0.7], color=\"k\", linestyle=\"--\", alpha=0.5)\n",
    "\n",
    "    # Draw red circles around the points that are misclassified (i.e. the points that are in the wrong bin)\n",
    "    for i, (X, y) in enumerate(zip([X_train, X_train], [y_train_idx, noisy_labels_idx])):\n",
    "        for j, (k, v) in enumerate(BINS_MAP.items()):\n",
    "            label_err = ax.scatter(\n",
    "                X[(y == v) & (y != y_train_idx), 0],\n",
    "                X[(y == v) & (y != y_train_idx), 1],\n",
    "                s=180,\n",
    "                marker=\"o\",\n",
    "                facecolor=\"none\",\n",
    "                edgecolors=\"red\",\n",
    "                linewidths=2.5,\n",
    "                alpha=0.5,\n",
    "                label=\"Label error\",\n",
    "            )\n",
    "\n",
    "\n",
    "    outlier = ax.scatter(X_out[:, 0], X_out[:, 1], color=\"k\", marker=\"x\", s=100, linewidth=2, label=\"Outlier\")\n",
    "\n",
    "    # Plot the exact duplicate\n",
    "    dups = ax.scatter(\n",
    "        X_duplicate[:, 0],\n",
    "        X_duplicate[:, 1],\n",
    "        color=\"c\",\n",
    "        marker=\"x\",\n",
    "        s=100,\n",
    "        linewidth=2,\n",
    "        label=\"Duplicates\",\n",
    "    )\n",
    "    \n",
    "    title_fontproperties = {\"weight\":\"semibold\", \"size\": 8}\n",
    "    first_legend = ax.legend(handles=[low, mid, high], loc=[0.76, 0.7], title=\"Given Class Label\", alignment=\"left\", title_fontproperties=title_fontproperties, fontsize=8, markerscale=0.5)\n",
    "    second_legend = ax.legend(handles=[label_err, outlier, dups], loc=[0.76, 0.46], title=\"Type of Issue\", alignment=\"left\", title_fontproperties=title_fontproperties, fontsize=8, markerscale=0.5)\n",
    "    \n",
    "    ax = plt.gca().add_artist(first_legend)\n",
    "    ax = plt.gca().add_artist(second_legend)\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_data(X_train, y_train_idx, noisy_labels_idx, X_out, X_duplicate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In real-world scenarios, you won't know the true labels or the distribution of the features, so we won't use these in this tutorial, except for evaluation purposes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get out-of-sample predicted probabilities from a classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To detect certain types of issues in classification data (e.g. label errors), `Datalab` relies on predicted class probabilities from a trained model. Ideally, the prediction for each example should be out-of-sample (to avoid overfitting), coming from a copy of the model that was not trained on this example. \n",
    "\n",
    "This tutorial uses a simple logistic regression model \n",
    "and the `cross_val_predict()` function from scikit-learn to generate out-of-sample predicted class probabilities for every example in the training set. You can replace this with *any* other classifier model and train it with cross-validation to get out-of-sample predictions.\n",
    "Make sure that the columns of your `pred_probs` are properly ordered with respect to the ordering of classes, which for Datalab is: lexicographically sorted by class name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LogisticRegression()\n",
    "pred_probs = cross_val_predict(\n",
    "    estimator=model, X=X_train, y=noisy_labels, cv=5, method=\"predict_proba\"\n",
    ")"
   ]
  },
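  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the column ordering mentioned above, here is a minimal sketch using a small made-up label array (`toy_labels` is hypothetical, not part of this tutorial's data). It assumes only that classes are ordered by lexicographically sorting their names, as stated above; `np.unique` returns its values in that sorted order.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical string labels, using the same class names as this tutorial\n",
    "toy_labels = np.array(['low', 'mid', 'high', 'mid', 'low'])\n",
    "\n",
    "# np.unique returns the sorted unique values: ['high', 'low', 'mid']\n",
    "class_order = np.unique(toy_labels)\n",
    "\n",
    "# Column j of pred_probs should hold the probability of class_order[j]\n",
    "column_of_class = {str(name): j for j, name in enumerate(class_order)}\n",
    "print(column_of_class)\n",
    "```\n",
    "\n",
    "Note that this sorted order (`high`, `low`, `mid`) differs from the integer encoding in `BINS_MAP`, so it is worth checking the column order whenever class names are strings."
   ]
  },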
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Instantiate Datalab object"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we instantiate the Datalab object that will be used in the remainder in the tutorial by passing in the data created above.\n",
    "\n",
    "`Datalab` has several ways of loading the data. In this case, we'll simply wrap the training features and noisy labels in a dictionary so that we can pass it to `Datalab`.\n",
    "\n",
    "Other supported data formats for `Datalab` include: [HuggingFace Datasets](https://huggingface.co/docs/datasets/index) and [pandas DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html). `Datalab` works across most data modalities (image, text, tabular, audio, etc). It is intended to find issues that commonly occur in datasets for which you have trained a supervised ML model, regardless of the type of data.\n",
    "\n",
    "Currently, pandas DataFrames that contain categorical columns might cause some issues when instantiating the `Datalab` object, so it is recommended to ensure that your DataFrame does not contain any categorical columns, or use other data formats (eg. python dictionary, HuggingFace Datasets) to pass in your data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {\"X\": X_train, \"y\": noisy_labels}\n",
    "\n",
    "lab = Datalab(data, label_name=\"y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Functionality 1**: Incremental issue search "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can call `find_issues` multiple times on a `Datalab` object to detect issues one type at a time.\n",
    "\n",
    "This is done via the `issue_types` argument which accepts a dictionary of issue types and any corresponding keyword arguments to specify nondefault keyword arguments to use for detecting each type of issues. In this first call, we only want to detect label issues, which are detected solely based on `pred_probs`, hence there is no need for us to pass in `features` here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lab.find_issues(pred_probs=pred_probs, issue_types={\"label\": {}})  \n",
    "lab.report()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can check for additional types of issues with the same `Datalab`. Here, we would like to detect outliers and near duplicates which both utilize the features of the data.\n",
    "\n",
    "Notice that this second call to `find_issues()` updates the output of `report()`, we can see the existing label issues detected alongside the new issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lab.find_issues(features=data[\"X\"], issue_types={\"outlier\": {}, \"near_duplicate\": {}})\n",
    "lab.report()"
   ]
  },
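  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After several `find_issues` calls, the per-example results accumulate in one table, with a boolean flag and a quality score for each issue type. As a minimal sketch, the hypothetical stand-in table below mimics the `is_<issue_name>_issue` and `<issue_name>_score` column naming (the naming is an assumption here, and the values are made up) and ranks the flagged examples from lowest to highest score. With a real `Datalab`, a table like this should be available from `lab.get_issues()` rather than being built by hand.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical stand-in for the per-example results of a label check\n",
    "issues = pd.DataFrame(\n",
    "    {\n",
    "        'is_label_issue': [False, True, False, True],\n",
    "        'label_score': [0.92, 0.08, 0.75, 0.31],\n",
    "    }\n",
    ")\n",
    "\n",
    "# Keep only flagged examples, lowest quality score first\n",
    "flagged = issues[issues['is_label_issue']].sort_values('label_score')\n",
    "worst_first = flagged.index.tolist()\n",
    "print(worst_first)\n",
    "```\n",
    "\n",
    "Sorting by score rather than relying on the flag alone makes it easier to review the most suspicious examples first."
   ]
  },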
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Functionality 2**: Specifying nondefault arguments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also overwrite previously-executed checks for a type of issue. Here we re-run the detection of outliers, but specify that different non-default settings should be used (in this case, the number of neighbors `k` compared against to determine which datapoints are outliers). \n",
    "The results from this new detection will replace the original outlier detection results in the updated `Datalab`. You could similarly specify non-default settings for other issue types in the first call to `Datalab.find_issues()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lab.find_issues(features=data[\"X\"], issue_types={\"outlier\": {\"k\": 30}})\n",
    "lab.report()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also increase the verbosity of the `report` to see additional information about the data issues and control how many top-ranked examples are shown for each issue."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lab.report(num_examples=10, verbosity=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice how the number of flagged outlier issues has changed after specfying different settings to use for outlier detection."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Functionality 3**: Save and load Datalab objects"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A `Datalab` can be saved to a folder at a specified path. In a future Python process, this path can be used to load the `Datalab` from file back into memory. Your dataset is not saved as part of this process, so you'll need to save/load it separately to keep working with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"datalab-files\"\n",
    "lab.save(path, force=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can load a `Datalab` object we have on file and view the previously detected issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_lab = Datalab.load(path)\n",
    "new_lab.report()"
   ]
  },
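  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the dataset itself is not stored when a `Datalab` is saved, one option is to persist the arrays separately. Below is a minimal sketch using NumPy's `.npz` format; the file name `dataset.npz`, the temporary directory, and the toy arrays are all assumptions for illustration, and any other storage format would work equally well.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy dataset with the same layout as this tutorial's `data` dict\n",
    "X = np.random.rand(5, 2)\n",
    "y = np.array(['low', 'mid', 'high', 'mid', 'low'])\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    data_path = os.path.join(tmp_dir, 'dataset.npz')\n",
    "\n",
    "    # Save features and labels together in one archive\n",
    "    np.savez(data_path, X=X, y=y)\n",
    "\n",
    "    # Later (e.g. in a new Python process), load them back into a dict\n",
    "    with np.load(data_path) as loaded:\n",
    "        restored = {'X': loaded['X'], 'y': loaded['y']}\n",
    "\n",
    "print(restored['X'].shape, restored['y'])\n",
    "```\n",
    "\n",
    "Keeping the dataset file next to the saved `Datalab` folder makes it easier to reload both together."
   ]
  },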
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Functionality 4**: Adding a custom IssueManager"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Datalab` detects pre-defined types of issues for you in one line of code: `find_issues()`. What if you want to check for other custom types of issues along with these pre-defined types, all within the same line of code?\n",
    "\n",
    "All issue types in `Datalab` are subclasses of cleanlab's `IssueManager` class.\n",
    "To register a custom issue type for use with `Datalab`, simply also make it a subclass of `IssueManager`.\n",
    "\n",
    "The necessary members to implement in the subclass are:\n",
    "\n",
    "- A class variable called `issue_name` that acts as a unique identifier for the type of issue.\n",
    "- An instance method called `find_issues` that:\n",
    "  - Computes a quality score for each example in the dataset (between 0-1), in terms of how *unlikely* it is to be an issue.\n",
    "  - Flags each example as an issue or not (may be based on thresholding the quality scores).\n",
    "  - Combine these in a dataframe that is assigned to an `issues` attribute of the `IssueManager`.\n",
    "  - Define a summary score for the overall quality of entire dataset, in terms of this type of issue. Set this score as part of the `summary` attribute of the `IssueManager`.\n",
    "  \n",
    "To demonstrate this, we create an arbitrary issue type that checks the divisibility of an example's index in the dataset by 13."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cleanlab.datalab.internal.issue_manager import IssueManager\n",
    "from cleanlab.datalab.internal.issue_manager_factory import register\n",
    "\n",
    "\n",
    "def scoring_function(idx: int, div: int = 13) -> float:\n",
    "    if idx == 0:\n",
    "        # Zero excluded from the divisibility check, gets the highest score\n",
    "        return 1\n",
    "    rem = idx % div\n",
    "    inv_scale = idx // div\n",
    "    if rem == 0:\n",
    "        return 0.5 * (1 - np.exp(-0.1*(inv_scale-1)))\n",
    "    else:\n",
    "        return 1 - 0.49 * (1 - np.exp(-inv_scale**0.5))*rem/div\n",
    "\n",
    "\n",
    "@register  # register this issue type for use with Datalab\n",
    "class SuperstitionIssueManager(IssueManager):\n",
    "    \"\"\"A custom issue manager that keeps track of issue indices that\n",
    "    are divisible by 13.\n",
    "    \"\"\"\n",
    "    description: str = \"Examples with indices that are divisible by 13 may be unlucky.\"  # Optional\n",
    "    issue_name: str = \"superstition\"\n",
    "\n",
    "    def find_issues(self, div=13, **_) -> None:\n",
    "        ids = self.datalab.issues.index.to_series()\n",
    "        issues_mask = ids.apply(lambda idx: idx % div == 0 and idx != 0)\n",
    "        scores = ids.apply(lambda idx: scoring_function(idx, div))\n",
    "        self.issues = pd.DataFrame(\n",
    "            {\n",
    "                f\"is_{self.issue_name}_issue\": issues_mask,\n",
    "                self.issue_score_key: scores,\n",
    "            },\n",
    "        )\n",
    "        summary_score = 1 - sum(issues_mask) / len(issues_mask)\n",
    "        self.summary = self.make_summary(score = summary_score)"
   ]
  },
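  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what this custom check computes without running `Datalab`, the sketch below repeats the scoring function defined above and applies the same mask and summary logic to a hypothetical dataset of 100 examples (the size `n_examples` is an assumption for illustration only).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def scoring_function(idx: int, div: int = 13) -> float:\n",
    "    # Same logic as the scoring function defined above\n",
    "    if idx == 0:\n",
    "        return 1\n",
    "    rem = idx % div\n",
    "    inv_scale = idx // div\n",
    "    if rem == 0:\n",
    "        return 0.5 * (1 - np.exp(-0.1 * (inv_scale - 1)))\n",
    "    return 1 - 0.49 * (1 - np.exp(-inv_scale**0.5)) * rem / div\n",
    "\n",
    "\n",
    "n_examples = 100  # hypothetical dataset size\n",
    "ids = pd.Series(range(n_examples))\n",
    "issues_mask = ids.apply(lambda idx: idx % 13 == 0 and idx != 0)\n",
    "scores = ids.apply(scoring_function)\n",
    "\n",
    "# Flagged indices are the nonzero multiples of 13: 13, 26, ..., 91\n",
    "print(ids[issues_mask].tolist())\n",
    "\n",
    "# Flagged examples score below 0.5, all other examples score above 0.5\n",
    "print(scores[issues_mask].max(), scores[~issues_mask].min())\n",
    "\n",
    "# Summary score: the fraction of examples that are not flagged\n",
    "summary_score = 1 - issues_mask.sum() / len(issues_mask)\n",
    "print(summary_score)\n",
    "```\n",
    "\n",
    "The scores are designed so that a threshold of 0.5 separates flagged from unflagged examples, which keeps the boolean flag and the quality score consistent with each other."
   ]
  },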
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once registered, this `IssueManager` will perform custom issue checks when `find_issues` is called on a `Datalab` instance.\n",
    "\n",
    "As our `Datalab` instance here already has results from the outlier and near duplicate checks, we perform the custom issue check separately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lab.find_issues(issue_types={\"superstition\": {}})\n",
    "lab.report()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
